Chapter 3: Sampling the Imaginary

[1]:
import random
from typing import Sequence

import arviz as az
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
import pandas as pd
import plotly
import plotly.graph_objects as go
import plotly.io as pio
from scipy.stats import gaussian_kde

pd.options.plotting.backend = "plotly"  # render pandas .plot() through plotly

# Fixed seed so the sampling cells below are reproducible.
seed = 84735
pio.templates.default = "plotly_white"
# Base JAX PRNG key; note it is reused (not split) by several cells below.
rng = jax.random.PRNGKey(seed)
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

Code

Code 3.1

[2]:
# Bayes' rule for the vampire screening example:
# P(vampire | +) = P(+ | vampire) P(vampire) / P(+).
p_positive_vampire = 0.95  # test sensitivity, P(+ | vampire)
p_positive_mortal = 0.01   # false-positive rate, P(+ | mortal)
p_vampire = 0.001          # prevalence, P(vampire)

# Marginal probability of a positive test (law of total probability).
p_positive = (
    p_positive_vampire * p_vampire
    + p_positive_mortal * (1 - p_vampire)
)
p_vampire_positive = (p_positive_vampire * p_vampire) / p_positive
p_vampire_positive
[2]:
0.08683729433272395

Code 3.2

[3]:
def calculate_posterior(W: int, L: int, prior: Sequence[float], grid_size: int):
    """Grid-approximate the posterior of the globe-tossing water proportion.

    Args:
        W: observed number of water tosses (successes).
        L: observed number of land tosses (failures).
        prior: prior weights evaluated on the grid (length ``grid_size``).
        grid_size: number of evenly spaced grid points on [0, 1].

    Returns:
        Normalized posterior probabilities over the grid (sums to 1).
    """
    p_grid = jnp.linspace(0, 1, grid_size)
    binomial = dist.Binomial(total_count=W + L, probs=p_grid)
    likelihood = jnp.exp(binomial.log_prob(W))
    unnormalized = prior * likelihood
    return unnormalized / unnormalized.sum()


# Globe-tossing data: 6 water in 9 tosses, flat prior on a 1,000-point grid.
W, L = 6, 3
grid_size = 1_000
prior = jnp.full(grid_size, 1)
p_grid = jnp.linspace(0, 1, grid_size)
posterior = calculate_posterior(W, L, prior, grid_size)

Code 3.3

[4]:
# Draw 10,000 grid points with probability proportional to the posterior.
draw_idx = dist.Categorical(probs=posterior).sample(rng, (10_000,))
samples = p_grid[draw_idx]

Code 3.4

[5]:
# Scatter of the posterior samples in draw order.
# FIX: with mode="markers" plotly ignores the `line` styling, so the
# original rgba color was never applied; marker color must be set via
# the `marker` property instead.
fig = go.Figure()
fig.add_trace(
    go.Scatter(
        x=jnp.arange(10_000),
        y=samples,
        mode="markers",
        marker={"color": "rgba(0, 0, 255, 0.2)"},
    )
)

Code 3.5

[6]:
az.plot_density({"": samples}, hdi_prob=1)
[6]:
array([[<AxesSubplot: >]], dtype=object)
../_images/notebooks_03_sampling_the_imaginary_11_1.png

Code 3.6

[7]:
posterior[p_grid < 0.5].sum()
[7]:
DeviceArray(0.17187458, dtype=float32)

Code 3.7

[8]:
jnp.sum(samples < 0.5) / samples.shape[0]
[8]:
DeviceArray(0.1756, dtype=float32)

Code 3.8

[9]:
jnp.sum(jnp.logical_and(samples > 0.5, samples < 0.75)) / samples.shape[0]
[9]:
DeviceArray(0.5978, dtype=float32)

Code 3.9

[10]:
jnp.quantile(samples, 0.8)
[10]:
DeviceArray(0.7617617, dtype=float32)

Code 3.10

[11]:
jnp.quantile(samples, jnp.array([0.1, 0.9]))
[11]:
DeviceArray([0.45245245, 0.8128128 ], dtype=float32)

Code 3.11

[12]:
# Extreme data: three waters in three tosses; the posterior piles up near p = 1.
posterior = calculate_posterior(W=3, L=0, prior=jnp.full(1_000, 1), grid_size=1_000)
skewed_idx = dist.Categorical(probs=posterior).sample(rng, (10_000,))
samples = p_grid[skewed_idx]

Code 3.12

[13]:
def percentile_interval(samples, prob):
    """Equal-tailed (central) percentile interval of the samples.

    `prob` may be given either as the interval mass (e.g. 0.89) or as the
    excluded mass (e.g. 0.11); both yield the same interval because the
    smaller of `prob` and `1 - prob` is treated as the total tail mass.
    """
    tail = prob if prob <= 1 - prob else 1 - prob
    return jnp.quantile(samples, jnp.array([tail / 2, 1 - tail / 2]))


percentile_interval(samples, 0.5)
[13]:
DeviceArray([0.7067067, 0.9319319], dtype=float32)

Code 3.13

[14]:
numpyro.diagnostics.hpdi(samples, prob=0.5)
[14]:
array([0.8398398, 0.998999 ], dtype=float32)

Code 3.14

[15]:
p_grid[jnp.argmax(posterior)]
[15]:
DeviceArray(1., dtype=float32)

Code 3.15

[16]:
samples[jnp.argmax(gaussian_kde(samples, bw_method=0.01)(samples))]
[16]:
DeviceArray(0.985986, dtype=float32)

Code 3.16

[17]:
# Alternative point estimates: posterior mean (displayed) and median (cell output).
display(samples.mean())
jnp.median(samples)
DeviceArray(0.8006291, dtype=float32)
[17]:
DeviceArray(0.8408408, dtype=float32)

Code 3.17

[18]:
jnp.sum(jnp.abs(0.5 - p_grid) * posterior)
[18]:
DeviceArray(0.31287518, dtype=float32)

Code 3.18

[19]:
def _expected_abs_loss(d):
    # Posterior-expected |d - p| for a single candidate decision d.
    return jnp.sum(jnp.abs(d - p_grid) * posterior)


# Vectorize the loss over every candidate decision on the grid.
loss = jax.vmap(_expected_abs_loss)(p_grid)
display(pd.DataFrame(loss, index=p_grid).plot())

Code 3.19

[20]:
p_grid[jnp.argmin(loss)]
[20]:
DeviceArray(0.8408408, dtype=float32)

Code 3.20

[21]:
jnp.exp(dist.Binomial(total_count=2, probs=0.7).log_prob(jnp.arange(3)))
[21]:
DeviceArray([0.08999996, 0.42000008, 0.48999974], dtype=float32)

Code 3.21

[22]:
# One simulated realization of two tosses with p = 0.7; the seed handler
# makes the numpyro.sample call reproducible.
with numpyro.handlers.seed(rng_seed=seed):
    dummy_w = numpyro.sample("dummy_w", dist.Binomial(total_count=2, probs=0.7))
dummy_w
[22]:
DeviceArray(2, dtype=int32, weak_type=True)

Code 3.22

[23]:
# Ten simulated two-toss experiments with p = 0.7.
with numpyro.handlers.seed(rng_seed=seed):
    dummy_w = numpyro.sample(
        "dummy_w", dist.Binomial(total_count=2, probs=0.7), sample_shape=(10,)
    )
dummy_w
[23]:
DeviceArray([0, 1, 2, 1, 1, 2, 1, 2, 1, 2], dtype=int32, weak_type=True)

Code 3.23

[24]:
# Simulate 100,000 two-toss experiments and tabulate the empirical
# frequencies; they approximate the analytical binomial probabilities above.
n_sims = 100_000
with numpyro.handlers.seed(rng_seed=seed):
    dummy_w = numpyro.sample(
        "dummy_w", dist.Binomial(total_count=2, probs=0.7), sample_shape=(n_sims,)
    )
dummy_w = pd.DataFrame(dummy_w, columns=["dummy_w"]).assign(freq=1)
dummy_w.groupby("dummy_w").sum() / n_sims
[24]:
freq
dummy_w
0 0.09004
1 0.42109
2 0.48887

Code 3.24

[25]:
# Sampling distribution of the water count in nine tosses at p = 0.7.
with numpyro.handlers.seed(rng_seed=seed):
    dummy_w = numpyro.sample(
        "dummy_w", dist.Binomial(total_count=9, probs=0.7), sample_shape=(100_000,)
    )
dummy_w = pd.DataFrame(dummy_w, columns=["dummy_w"])
dummy_w.plot(kind="hist")

Code 3.25

[26]:
# Predictive simulation at a single fixed parameter value p = 0.6.
fixed_p_key = jax.random.PRNGKey(seed)
w = dist.Binomial(total_count=9, probs=0.6).sample(fixed_p_key, (10_000,))
pd.DataFrame(w).plot(kind="hist")

Code 3.26

[27]:
# Full posterior predictive: propagate parameter uncertainty by drawing one
# binomial outcome per posterior sample of p.
predictive_key = jax.random.PRNGKey(seed)
w = dist.Binomial(total_count=9, probs=samples).sample(predictive_key)
pd.DataFrame(w).plot(kind="hist")

Hard

3H1

[28]:
# fmt: off
# First and second births for 100 two-child families (1 = boy, 0 = girl);
# presumably the homework birth data for the chapter-3 "Hard" exercises --
# TODO confirm source. Row 0 = first births, row 1 = second births.
births_1 = [
    1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0,
    0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0,
    0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0,
    0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1,
]
births_2 = [
    0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1,
    1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0,
    0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0,
]
# Shape (2, 100): births[0] = first-born children, births[1] = second-born.
births = jnp.array([births_1, births_2])
[29]:
grid_size = 1_000

# Grid-approximate the posterior probability that a birth is a boy, with a
# flat prior and a binomial likelihood for the total boy count in all 200
# births. Fixes the misspelled `likelhihood` variable and builds the prior
# directly as a jnp array instead of a Python list converted later.
p_grid = jnp.linspace(0, 1, grid_size)
prior = jnp.full(grid_size, 0.5)
likelihood = jnp.exp(
    dist.Binomial(total_count=births.size, probs=p_grid).log_prob(births.sum())
)
raw_posterior = likelihood * prior
posterior = raw_posterior / raw_posterior.sum()
map_p = p_grid[jnp.argmax(posterior)]
print(f"p={map_p:.2%} maximizes the posterior probability.")
p=55.46% maximizes the posterior probability.

3H2

[30]:
# Sample from the grid posterior and report highest posterior density
# intervals. FIX: the printed labels misspelled HPDI (numpyro's `hpdi`,
# highest posterior density interval) as "HDPI".
posterior_samples = p_grid[dist.Categorical(probs=posterior).sample(rng, (10_000,))]
print(f"50% HPDI: {numpyro.diagnostics.hpdi(posterior_samples, prob=0.5)}")
print(f"89% HPDI: {numpyro.diagnostics.hpdi(posterior_samples, prob=0.89)}")
print(f"97% HPDI: {numpyro.diagnostics.hpdi(posterior_samples, prob=0.97)}")
50% HPDI: [0.5275275 0.5745746]
89% HPDI: [0.4994995 0.6096096]
97% HPDI: [0.47847846 0.6286286 ]

3H3

[31]:
# Posterior predictive check against the training data: one simulated
# 200-birth replicate per posterior sample of p.
# NOTE(review): `rng` is reused unsplit here (JAX key-reuse smell); kept
# as-is to preserve the recorded outputs.
posterior_predictive_samples = dist.Binomial(
    total_count=births.size, probs=posterior_samples
).sample(rng)
mean_boys = posterior_predictive_samples.mean()
print(
    f"Posterior predictive distribution of number of boys has mean {mean_boys:.0f} "
    f"vs observation of {births.sum()}: we're evaluating model against training data"
)
pd.DataFrame(posterior_predictive_samples, columns=["n_boys"]).plot(kind="hist")
Posterior predictive distribution of number of boys has mean 111 vs observation of 111: we're evaluating model against training data

3H4

[32]:
# Predictive check restricted to first births (100 trials per replicate).
# FIX: corrects the "obersvation" typo in the printed message.
posterior_predictive_samples = dist.Binomial(
    total_count=births.shape[1], probs=posterior_samples
).sample(rng)
print(
    f"Posterior predictive distribution of first born sons has mean {posterior_predictive_samples.mean():.0f} "
    f"vs observation of {births[0].sum()}; still reasonable but not as good as purely 'in-sample'."
)
pd.DataFrame(posterior_predictive_samples, columns=["n_first_born_boys"]).plot(
    kind="hist"
)
Posterior predictive distribution of first born sons has mean 55 vs observation of 51; still reasonable but not as good as purely 'in-sample'.

3H5

[33]:
# Condition on families whose first child was a girl and predict the number
# of second-born boys among them; the model assumes independent births.
# NOTE(review): the printed "observations" figure is births[1].sum() -- ALL
# second-born boys, not only those following a first-born girl -- TODO
# confirm the intended comparison.
n_first_born_girls = jnp.logical_not(births[0]).sum()
posterior_predictive_samples = dist.Binomial(
    total_count=n_first_born_girls, probs=posterior_samples
).sample(rng)
print(
    f"PPD of boys with big sisters of {posterior_predictive_samples.mean():.0f} "
    f"is completely out of line with observations of {births[1].sum()}: we didn't model "
    "the correlation between first and second birth that's present in our dataset."
)
pd.DataFrame(posterior_predictive_samples, columns=["n_boys_with_big_sister"]).plot(
    kind="hist"
)
PPD of boys with big sisters of 27 is completely out of line with observations of 60: we didn't model the correlation between first and second birth that's present in our dataset.
[ ]: